{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-1-2bf7dd5bfcf7>:72: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../../../datasets/MNIST_data\\train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../../../datasets/MNIST_data\\train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting ../../../datasets/MNIST_data\\t10k-images-idx3-ubyte.gz\n",
      "Extracting ../../../datasets/MNIST_data\\t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "After 1 training step(s), loss on training batch is 4.90434.\n",
      "After 1001 training step(s), loss on training batch is 0.716846.\n",
      "After 2001 training step(s), loss on training batch is 0.685001.\n",
      "After 3001 training step(s), loss on training batch is 0.654261.\n",
      "After 4001 training step(s), loss on training batch is 0.67465.\n",
      "After 5001 training step(s), loss on training batch is 0.645965.\n",
      "After 6001 training step(s), loss on training batch is 0.64026.\n",
      "After 7001 training step(s), loss on training batch is 0.63272.\n",
      "After 8001 training step(s), loss on training batch is 0.630295.\n",
      "After 9001 training step(s), loss on training batch is 0.657804.\n",
      "After 10001 training step(s), loss on training batch is 0.630289.\n",
      "After 11001 training step(s), loss on training batch is 0.63863.\n",
      "After 12001 training step(s), loss on training batch is 0.682623.\n",
      "After 13001 training step(s), loss on training batch is 0.61866.\n",
      "After 14001 training step(s), loss on training batch is 0.641642.\n",
      "After 15001 training step(s), loss on training batch is 0.610577.\n",
      "After 16001 training step(s), loss on training batch is 0.622926.\n",
      "After 17001 training step(s), loss on training batch is 0.61515.\n",
      "After 18001 training step(s), loss on training batch is 0.609648.\n",
      "After 19001 training step(s), loss on training batch is 0.612745.\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import LeNet5_inference\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "# 定义神经网络相关的参数\n",
    "BATCH_SIZE = 100\n",
    "LEARNING_RATE_BASE = 0.01\n",
    "LEARNING_RATE_DECAY = 0.99\n",
    "REGULARIZATION_RATE = 0.0001\n",
    "TRAINING_STEPS = 20000\n",
    "MOVING_AVERAGE_DECAY = 0.99\n",
    "\n",
    "# 模型保存的路径和文件名\n",
    "MODEL_SAVE_PATH=\"MNIST_model/\"\n",
    "MODEL_NAME=\"mnist_model\"\n",
    "\n",
    "\n",
    "# 与第5章的区别在于，输入数据需要更改为四维\n",
    "def train(mnist):\n",
    "    # 1. 定义输入输出（参数在inference函数中）\n",
    "    x = tf.placeholder(tf.float32, [\n",
    "            BATCH_SIZE,\n",
    "            LeNet5_inference.IMAGE_SIZE,\n",
    "            LeNet5_inference.IMAGE_SIZE,\n",
    "            LeNet5_inference.NUM_CHANNELS],\n",
    "        name='x-input')\n",
    "    y_ = tf.placeholder(tf.float32, [None, LeNet5_inference.OUTPUT_NODE], name='y-input')\n",
    "    \n",
    "    # 2. 定义前向传播、损失函数、反向传播\n",
    "    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)\n",
    "    y = LeNet5_inference.inference(x, False, regularizer)\n",
    "    \n",
    "    global_step = tf.Variable(0, trainable=False)\n",
    "    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)\n",
    "    variables_averages_op = variable_averages.apply(tf.trainable_variables())\n",
    "    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))\n",
    "    cross_entropy_mean = tf.reduce_mean(cross_entropy)\n",
    "    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))\n",
    "    \n",
    "    learning_rate = tf.train.exponential_decay(\n",
    "        LEARNING_RATE_BASE,\n",
    "        global_step,\n",
    "        mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,\n",
    "        staircase=True)\n",
    "    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)\n",
    "    with tf.control_dependencies([train_step, variables_averages_op]):\n",
    "        train_op = tf.no_op(name='train')\n",
    "        \n",
    "    # 3. 建立会话，训练\n",
    "    saver = tf.train.Saver()    # 初始化TensorFlow持久化类\n",
    "    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45)\n",
    "    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:\n",
    "        tf.global_variables_initializer().run()\n",
    "        for i in range(TRAINING_STEPS):\n",
    "            xs, ys = mnist.train.next_batch(BATCH_SIZE)\n",
    "\n",
    "            reshaped_xs = np.reshape(xs, (\n",
    "                BATCH_SIZE,\n",
    "                LeNet5_inference.IMAGE_SIZE,\n",
    "                LeNet5_inference.IMAGE_SIZE,\n",
    "                LeNet5_inference.NUM_CHANNELS))\n",
    "            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: reshaped_xs, y_: ys})\n",
    "\n",
    "            if i % 1000 == 0:\n",
    "                print(\"After %d training step(s), loss on training batch is %g.\" % (step, loss_value))\n",
    "                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)\n",
    "                \n",
    "                \n",
    "def main(argv=None):\n",
    "    mnist = input_data.read_data_sets(\"../../../datasets/MNIST_data\", one_hot=True)\n",
    "    train(mnist)\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
